#%%
import moviepy.editor as mpy
import matplotlib.pyplot as plt, numpy as np, seaborn as sns, tqdm, pathlib, pims, cv2 as cv, xarray as xf, pandas as pd, scipy.ndimage as ndimage


# Pixel-size calibration: a known ruler length divided by the number of image
# columns it spans gives millimetres per pixel. It is measured on a clip taken
# at the start and on one taken at the end, and the two values are averaged.
ruler_mm = 10

# vid = pims.Video('20220210_ruler_start.mp4')
# vid[2500][195:230,558:592]
pix2mm_start = ruler_mm / (592 - 558)

# vid = pims.Video('20220210_ruler_end.mp4')
# vid[2500][195:230,557:590]
pix2mm_end = ruler_mm / (590 - 557)

pix2mm = np.mean((pix2mm_start, pix2mm_end))
print(pix2mm)

#%%
# Input clip (file name without extension) and the short tag used in the names
# of the figure and CSV written by the analysis cell.
vidname = '20220210_250fps_michaels-gravel_4'
vid = pims.Video('%s.mp4'%vidname)
name11 = 'mk_4'
# img = vid[1440][:,150:-150]
# cv.imwrite('poster_pic2.png', img)

# Frame range to process: the whole clip by default. The commented pair below
# restricts it to a couple of frames for quick checks.
start = 0
# Ask the reader for its length through the public len() protocol rather than
# reaching into its private ``_len`` attribute.
nframes = len(vid)
# start = 2000
# nframes = 2

# One column per processed frame; the tracking loop stores the value returned
# by center_of_mass there (row 0 = image row, row 1 = image column).
cent = np.zeros((2,nframes))
# Time between consecutive frames in seconds (the file name says 250 fps).
dt = 1/250

#%%
# Track the particle frame by frame: binarize the cropped frame and store the
# centre of mass of the foreground mask in ``cent`` (NaN when the mask is empty).

# The structuring element never changes, so build it once outside the loop.
kernel = np.ones((5,5),np.uint8)
for i in tqdm.tqdm(range(start,start+nframes)):
    img = vid[i]
    img = img[:,150:900]  # keep image columns 150..899 only
    img = cv.cvtColor(img, cv.COLOR_RGB2GRAY)
    img = cv.medianBlur(img, 7)
    # Local (Gaussian-weighted) threshold; the negative offset means a pixel
    # must be brighter than its neighbourhood mean by 11 to count as foreground.
    img = cv.adaptiveThreshold(img, 255, cv.ADAPTIVE_THRESH_GAUSSIAN_C, cv.THRESH_BINARY, 21, -11)
    img = cv.medianBlur(img, 5)
    # Morphological closing fills small holes and gaps in the mask.
    img = cv.morphologyEx(img, cv.MORPH_CLOSE, kernel, iterations=5)
    img[img>0] = 1
    img = img[10:-10,:]  # drop 10 rows at the top and bottom
    # plt.imshow(img)
    if img.max() > 0:
        # ``scipy.ndimage.measurements`` is a deprecated namespace; call the
        # function from ``scipy.ndimage`` directly.
        cent[:,i-start] = ndimage.center_of_mass(img)
    else:
        # Nothing detected in this frame.
        cent[:,i-start] = np.nan

#%%
def _five_point_velocity(pos, dt):
    """Differentiate evenly sampled positions with respect to time.

    Interior samples use the fourth-order central stencil
    ``(-p[i+2] + 8*p[i+1] - 8*p[i-1] + p[i-2]) / (12*dt)``. The two samples at
    each end, where that stencil does not fit, use one-sided first
    differences (forward at the start, backward at the end).

    Parameters
    ----------
    pos : 1-D float array of positions, one per frame (at least 3 samples).
    dt : time between consecutive samples.

    Returns
    -------
    Array with the same shape as ``pos``. NaN positions propagate to every
    velocity sample whose stencil touches them.
    """
    vel = np.zeros_like(pos)
    vel[0] = (pos[1] - pos[0]) / dt
    vel[1] = (pos[2] - pos[1]) / dt
    vel[-2] = (pos[-2] - pos[-3]) / dt
    vel[-1] = (pos[-1] - pos[-2]) / dt
    # Whole-array form of the stencil; the slices line up p[i+2], p[i+1],
    # p[i-1] and p[i-2] for every interior index i = 2 .. n-3.
    vel[2:-2] = (-pos[4:] + 8*pos[3:-1] - 8*pos[1:-3] + pos[:-4]) / (12*dt)
    return vel


# Convert the tracked centroid from pixels to millimetres.
# NOTE(review): y is taken from the image *column* coordinate (cent[1]) and
# negated, x from the row coordinate -- presumably the footage is rotated so
# that settling happens along the column axis; confirm against the video.
y_pos = (-cent[1,:]) * pix2mm
x_pos = cent[0,:] * pix2mm
x_vel = _five_point_velocity(x_pos, dt)
y_vel = _five_point_velocity(y_pos, dt)

# Speed used for the analysis: only the y component is considered (the
# commented line is the full two-component magnitude).
# vel_mag = np.sqrt(x_vel**2 + y_vel**2)
vel_mag = np.sqrt(y_vel**2)
time = (np.ones_like(vel_mag) * dt).cumsum()
df = pd.DataFrame({'time': time, 'v': vel_mag, 'y': y_pos})
# Bridge gaps of up to 11 frames, then smooth with a centred 31-frame moving
# average that needs at least 11 valid samples in the window.
df['v_smooth'] = df.v.interpolate(limit=11).rolling(window=31, min_periods=11, center=True).mean()

# Label each contiguous run of finite ``v_smooth`` as one particle track.
# The id of a row is the number of non-finite runs that start at or before it,
# so the first track is 0 when the data start finite and 1 otherwise. Rows
# inside the gaps themselves keep id 0, exactly as the previous row-by-row
# loop left them.
# NOTE(review): because of that, gap rows are grouped together with track 0
# by the later groupby('particleId').
#
# The ids are written with a single column assignment: the old chained
# ``df['particleId'][ii] = ...`` assigns into a temporary under pandas
# copy-on-write and would leave every id at 0.
finite = np.isfinite(df.v_smooth.values)
gap_start = ~finite
gap_start[1:] &= finite[:-1]  # only the first row of each non-finite run
df['particleId'] = np.where(finite, np.cumsum(gap_start), 0).astype(df.v.values.dtype)

# Plot the smoothed speed of every accepted track against height, then write
# the per-track mean and standard deviation to a CSV file.
count = 0  # number of accepted tracks
plt.figure(0)
dfs = []
for _, group in df.groupby('particleId'):
    # Keep tracks whose raw speed is steady (std < 200) and fast enough
    # (mean > 100); units are mm/s according to the axis label below.
    if (group.v.std() < 200) and (group.v.mean() > 100):
        count += 1
        # Work on an explicit copy and mask through .loc. The previous chained
        # form ``group.v_smooth[mask] = np.nan`` writes into a temporary under
        # pandas copy-on-write, so the height window would silently not apply.
        group = group.copy()
        height = -group.y
        # Discard samples with no position or outside the 90..160 window.
        outside = ~np.isfinite(group.y) | (height < 90) | (height > 160)
        group.loc[outside, 'v_smooth'] = np.nan
        # plt.plot(group.time.values - group.time.values[0], group.v_smooth.values, '-')
        plt.plot(height, group.v_smooth.values, '-')
        dfs.append(group)
plt.ylim(0, 500)
plt.xlabel('vertical distance (mm)')
plt.ylabel('velocity (mm/s)')

if not dfs:
    # pd.concat([]) would fail with an unhelpful "No objects to concatenate".
    raise ValueError('no particle track passed the velocity filter; nothing to summarise')

per_track = pd.concat(dfs).groupby('particleId').v_smooth
v_mean = per_track.mean()
v_std = per_track.std()
plt.title('mean v: %g +/- %g' % (v_mean.mean(), v_std.mean()))
plt.savefig('terminal_settling_tracks_%s.pdf'%name11)
# NOTE: ``df`` is rebound here from the per-frame table to the per-track
# summary that gets written to disk.
df = pd.DataFrame(dict(vel = v_mean, vel_e = v_std))
df.to_csv('%s_settling_velocities.csv'%name11)

#%%
# Settings for the exported preview clip.
fps = 60  # playback frame rate of the written video
# Export the first 5000 frames, but never more than the video actually holds;
# a fixed 5000 would index past the end of a shorter clip.
num_images = min(nframes, 5000)
duration = num_images/fps  # clip length in seconds at the playback rate

def make_frame(t):
    """Return the binarized frame shown at time ``t`` of the exported clip.

    The source frame with index ``round(t * fps)`` is put through the same
    crop / blur / adaptive-threshold / closing steps as the tracking loop and
    returned as a 3-channel uint8 image whose foreground is 255 and whose
    background is 0.

    Parameters
    ----------
    t : time in seconds within the exported clip.

    Returns
    -------
    RGB image array of the binarized, cropped frame.
    """
    # Round instead of truncating: t*fps can come out a hair below the
    # intended integer (floating-point round-off), and int() alone would then
    # return the previous frame a second time.
    img = vid[int(round(t*fps))]
    img = img[:,150:900]
    img = cv.cvtColor(img, cv.COLOR_RGB2GRAY)
    img = cv.medianBlur(img, 7)
    img = cv.adaptiveThreshold(img, 255, cv.ADAPTIVE_THRESH_GAUSSIAN_C, cv.THRESH_BINARY, 21, -11)
    img = cv.medianBlur(img, 5)
    kernel = np.ones((5,5),np.uint8)
    img = cv.morphologyEx(img, cv.MORPH_CLOSE, kernel, iterations=5)
    # The mask is deliberately left at 0/255 here. Collapsing it to 0/1 (as the
    # tracking loop does for the centroid) would make the foreground one grey
    # level above black, i.e. an effectively black video.
    img = img[10:-10,:]
    img = cv.cvtColor(img, cv.COLOR_GRAY2RGB)
    return img

# Build a clip whose frames come from make_frame and encode it to a file
# named after the source video.
output_path = '%s_converted.mp4' % vidname
animation = mpy.VideoClip(make_frame, duration=duration)
animation.write_videofile(output_path, fps=fps)  # export as video

# %%
